Strassen’s Matrix Multiplication

Brute Force Approach


Brute Force Algorithm

  1. Initialization: The function void multiply(int A[][N], int B[][N], int C[][N]) takes three N x N matrices A, B, and C. The goal is to compute the product of matrices A and B, storing the result in matrix C.

  2. Nested Loops: The algorithm uses three nested loops:
    • Outer loop (i): Iterates over the rows of A.

    • Middle loop (j): Iterates over the columns of B.

    • Inner loop (k): Computes the dot product of the i-th row of A and the j-th column of B.

  3. Element Computation: For each element C[i][j]:
    • Initialize C[i][j] to 0.

    • Update it with the sum of products A[i][k] * B[k][j] for all k from 0 to N-1.

  4. Result: The final matrix C contains the product of matrices A and B.

  5. Time Complexity: The time complexity of this brute force approach is O(N^3): it performs N^3 multiplications and N^3 additions.
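The steps above can be sketched in Python as follows. This is a minimal illustration that assumes the matrices are square lists of lists; it mirrors the three nested loops rather than reproducing the C signature exactly.

```python
def multiply(A, B):
    """Brute-force product of two N x N matrices given as lists of lists."""
    n = len(A)
    C = [[0] * n for _ in range(n)]      # every C[i][j] starts at 0
    for i in range(n):                   # outer loop: rows of A
        for j in range(n):               # middle loop: columns of B
            for k in range(n):           # inner loop: dot product of row i and column j
                C[i][j] += A[i][k] * B[k][j]
    return C


A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
print(multiply(A, B))  # [[19, 22], [43, 50]]
```

The innermost statement runs N * N * N times, which is where the cubic cost comes from.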

Divide and Conquer Approach

Input: Two square matrices A and B of size n x n, where n is a power of 2 (e.g., 1x1, 2x2, 4x4, 8x8, 16x16, 32x32, etc.).
Output: The resultant matrix C = A * B, where C is a square matrix of size n x n.
Process:
If n = 2, apply Strassen's Matrix Multiplication Algorithm to compute the elements of matrix C:
M1 = (A11 + A22) * (B11 + B22)
M2 = (A21 + A22) * B11
M3 = A11 * (B12 - B22)
M4 = A22 * (B21 - B11)
M5 = (A11 + A12) * B22
M6 = (A21 - A11) * (B11 + B12)
M7 = (A12 - A22) * (B21 + B22)
Then, compute the submatrices of C:
C11 = M1 + M4 - M5 + M7
C12 = M3 + M5
C21 = M2 + M4
C22 = M1 - M2 + M3 + M6
If n > 2, apply the Divide and Conquer method:
Divide each of A and B into 4 submatrices of size n/2 x n/2 (8 submatrices in total).
Recursively compute the seven products M1 to M7 on these submatrices, treating A11, B11, and so on as blocks rather than single elements.
Result: Combine the computed submatrices C11, C12, C21, and C22 to form the final matrix C.
Performance: Strassen's approach is asymptotically faster than the standard brute force matrix multiplication algorithm, so it pays off for large matrices. For small matrices the brute force approach is usually more efficient, because it avoids the extra additions and the recursion overhead.
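As a quick check of the n = 2 formulas, the sketch below evaluates M1 to M7 for one small example. The function name strassen_2x2 is only an illustrative choice, not something defined elsewhere in this article.

```python
def strassen_2x2(A, B):
    """Multiply two 2 x 2 matrices with Strassen's seven products."""
    (a11, a12), (a21, a22) = A
    (b11, b12), (b21, b22) = B

    m1 = (a11 + a22) * (b11 + b22)
    m2 = (a21 + a22) * b11
    m3 = a11 * (b12 - b22)
    m4 = a22 * (b21 - b11)
    m5 = (a11 + a12) * b22
    m6 = (a21 - a11) * (b11 + b12)
    m7 = (a12 - a22) * (b21 + b22)

    return [[m1 + m4 - m5 + m7, m3 + m5],
            [m2 + m4, m1 - m2 + m3 + m6]]


# Here M1..M7 are 65, 35, -2, 8, 24, 22, -30.
print(strassen_2x2([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```

The result matches the ordinary row-by-column product, while using 7 multiplications instead of 8.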


Strassen’s Matrix Multiplication Algorithm

  1. Base Case (n = 1):
    • If n = 1, then the result matrix C11 is calculated as C11 = A11 * B11.

  2. Case for n = 2:
    • If n = 2, compute the elements of the result matrix as:
      • C11 = M1 + M4 − M5 + M7

      • C12 = M3 + M5

      • C21 = M2 + M4

      • C22 = M1 − M2 + M3 + M6

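  3. Recursive Case (n > 2):
    • If n > 2, divide the matrices into n/2 x n/2 submatrices. The submatrices of C satisfy the block identities:
      • C11 = SMM(A11, B11, n/2) + SMM(A12, B21, n/2)

      • C12 = SMM(A11, B12, n/2) + SMM(A12, B22, n/2)

      • C21 = SMM(A21, B11, n/2) + SMM(A22, B21, n/2)

      • C22 = SMM(A21, B12, n/2) + SMM(A22, B22, n/2)

    • Written this way there are eight recursive products. Strassen's algorithm obtains the same four submatrices from only seven recursive products, M1 to M7, using the n = 2 formulas with submatrices in place of single elements.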

  4. End: The result matrix is computed based on the above cases.
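For comparison, the block structure of the recursive case can be sketched as plain divide and conquer, where each submatrix of C is the sum of two recursive products. This sketch deliberately uses eight recursive products rather than Strassen's seven; the helper names split, add, and block_multiply are assumptions made for this illustration, and n is assumed to be a power of 2.

```python
def split(M):
    """Split a square matrix (list of lists) into four n/2 x n/2 blocks."""
    h = len(M) // 2
    top, bottom = M[:h], M[h:]
    return ([row[:h] for row in top], [row[h:] for row in top],
            [row[:h] for row in bottom], [row[h:] for row in bottom])


def add(X, Y):
    """Element-wise sum of two equally sized matrices."""
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def block_multiply(A, B):
    """Divide-and-conquer product with eight recursive calls (not Strassen)."""
    if len(A) == 1:
        return [[A[0][0] * B[0][0]]]
    A11, A12, A21, A22 = split(A)
    B11, B12, B21, B22 = split(B)
    C11 = add(block_multiply(A11, B11), block_multiply(A12, B21))
    C12 = add(block_multiply(A11, B12), block_multiply(A12, B22))
    C21 = add(block_multiply(A21, B11), block_multiply(A22, B21))
    C22 = add(block_multiply(A21, B12), block_multiply(A22, B22))
    # Stitch the four blocks back together row by row.
    top = [left + right for left, right in zip(C11, C12)]
    bottom = [left + right for left, right in zip(C21, C22)]
    return top + bottom


print(block_multiply([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```

With eight recursive calls the recurrence becomes T(n) = 8T(n/2) + cn^2, which is still cubic; the saving in Strassen's method comes entirely from replacing the eight products with seven.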

Strassen's Matrix Multiplication Code

                    
                      
                        #include <stdio.h>

                            int main() {
                                int a[2][2], b[2][2], c[2][2], i, j;
                                int m1, m2, m3, m4, m5, m6, m7;
                            
                                // Input first matrix
                                printf("Enter the 4 elements of the first matrix: ");
                                for (i = 0; i < 2; i++) {
                                    for (j = 0; j < 2; j++) {
                                        scanf("%d", &a[i][j]);
                                    }
                                }
                            
                                // Input second matrix
                                printf("Enter the 4 elements of the second matrix: ");
                                for (i = 0; i < 2; i++) {
                                    for (j = 0; j < 2; j++) {
                                        scanf("%d", &b[i][j]);
                                    }
                                }
                            
                                // Display first matrix
                                printf("\nThe first matrix is\n");
                                for (i = 0; i < 2; i++) {
                                    for (j = 0; j < 2; j++) {
                                        printf("%d\t", a[i][j]);
                                    }
                                    printf("\n");
                                }
                            
                                // Display second matrix
                                printf("\nThe second matrix is\n");
                                for (i = 0; i < 2; i++) {
                                    for (j = 0; j < 2; j++) {
                                        printf("%d\t", b[i][j]);
                                    }
                                    printf("\n");
                                }
                            
                                // Strassen's formula calculations
                                m1 = (a[0][0] + a[1][1]) * (b[0][0] + b[1][1]);
                                m2 = (a[1][0] + a[1][1]) * b[0][0];
                                m3 = a[0][0] * (b[0][1] - b[1][1]);
                                m4 = a[1][1] * (b[1][0] - b[0][0]);
                                m5 = (a[0][0] + a[0][1]) * b[1][1];
                                m6 = (a[1][0] - a[0][0]) * (b[0][0] + b[0][1]);
                                m7 = (a[0][1] - a[1][1]) * (b[1][0] + b[1][1]);
                            
                                // Calculating elements of result matrix c
                                c[0][0] = m1 + m4 - m5 + m7;
                                c[0][1] = m3 + m5;
                                c[1][0] = m2 + m4;
                                c[1][1] = m1 - m2 + m3 + m6;
                            
                                // Display result matrix
                                printf("\nThe resultant matrix after multiplication is\n");
                                for (i = 0; i < 2; i++) {
                                    for (j = 0; j < 2; j++) {
                                        printf("%d\t", c[i][j]);
                                    }
                                    printf("\n");
                                }
                            
                                return 0;
                            }
                            
                            
                            
                    
                
                    
                        
                        #include <iostream>
                            #include <vector>
                            
                            using namespace std;
                            
                            typedef vector<vector<int>> Matrix;
                            
                            // Function to add two matrices
                            Matrix add(const Matrix &A, const Matrix &B) {
                                int n = A.size();
                                Matrix C(n, vector<int>(n));
                                for (int i = 0; i < n; i++) {
                                    for (int j = 0; j < n; j++) {
                                        C[i][j] = A[i][j] + B[i][j];
                                    }
                                }
                                return C;
                            }
                            
                            // Function to subtract two matrices
                            Matrix subtract(const Matrix &A, const Matrix &B) {
                                int n = A.size();
                                Matrix C(n, vector<int>(n));
                                for (int i = 0; i < n; i++) {
                                    for (int j = 0; j < n; j++) {
                                        C[i][j] = A[i][j] - B[i][j];
                                    }
                                }
                                return C;
                            }
                            
                            // Function to multiply two matrices using Strassen's algorithm
                            Matrix strassen(const Matrix &A, const Matrix &B) {
                                int n = A.size();
                                if (n == 1) {
                                    Matrix C(1, vector<int>(1));
                                    C[0][0] = A[0][0] * B[0][0];
                                    return C;
                                }
                            
                                int newSize = n / 2;
                                Matrix A11(newSize, vector<int>(newSize));
                                Matrix A12(newSize, vector<int>(newSize));
                                Matrix A21(newSize, vector<int>(newSize));
                                Matrix A22(newSize, vector<int>(newSize));
                                Matrix B11(newSize, vector<int>(newSize));
                                Matrix B12(newSize, vector<int>(newSize));
                                Matrix B21(newSize, vector<int>(newSize));
                                Matrix B22(newSize, vector<int>(newSize));
                            
                                // Dividing matrices into 4 sub-matrices
                                for (int i = 0; i < newSize; i++) {
                                    for (int j = 0; j < newSize; j++) {
                                        A11[i][j] = A[i][j];
                                        A12[i][j] = A[i][j + newSize];
                                        A21[i][j] = A[i + newSize][j];
                                        A22[i][j] = A[i + newSize][j + newSize];
                                        B11[i][j] = B[i][j];
                                        B12[i][j] = B[i][j + newSize];
                                        B21[i][j] = B[i + newSize][j];
                                        B22[i][j] = B[i + newSize][j + newSize];
                                    }
                                }
                            
                                Matrix M1 = strassen(add(A11, A22), add(B11, B22));
                                Matrix M2 = strassen(add(A21, A22), B11);
                                Matrix M3 = strassen(A11, subtract(B12, B22));
                                Matrix M4 = strassen(A22, subtract(B21, B11));
                                Matrix M5 = strassen(add(A11, A12), B22);
                                Matrix M6 = strassen(subtract(A21, A11), add(B11, B12));
                                Matrix M7 = strassen(subtract(A12, A22), add(B21, B22));
                            
                                Matrix C11 = add(subtract(add(M1, M4), M5), M7);
                                Matrix C12 = add(M3, M5);
                                Matrix C21 = add(M2, M4);
                                Matrix C22 = add(subtract(add(M1, M3), M2), M6);
                            
                                Matrix C(n, vector<int>(n));
                                for (int i = 0; i < newSize; i++) {
                                    for (int j = 0; j < newSize; j++) {
                                        C[i][j] = C11[i][j];
                                        C[i][j + newSize] = C12[i][j];
                                        C[i + newSize][j] = C21[i][j];
                                        C[i + newSize][j + newSize] = C22[i][j];
                                    }
                                }
                            
                                return C;
                            }
                            
                            void printMatrix(const Matrix &M) {
                                int n = M.size();
                                for (int i = 0; i < n; i++) {
                                    for (int j = 0; j < n; j++) {
                                        cout << M[i][j] << " ";
                                    }
                                    cout << endl;
                                }
                            }
                            
                            int main() {
                                // Example 4 x 4 inputs (the size must be a power of 2)
                                Matrix A = {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}};
                                Matrix B = {{16, 15, 14, 13}, {12, 11, 10, 9}, {8, 7, 6, 5}, {4, 3, 2, 1}};
                            
                                Matrix C = strassen(A, B);
                                cout << "Product Matrix:" << endl;
                                printMatrix(C);
                            
                                return 0;
                            }
                            
                    
                
                    
                        public class StrassenMatrixMultiplication {

                            // Function to add two matrices
                            public static int[][] add(int[][] A, int[][] B) {
                                int n = A.length;
                                int[][] C = new int[n][n];
                                for (int i = 0; i < n; i++) {
                                    for (int j = 0; j < n; j++) {
                                        C[i][j] = A[i][j] + B[i][j];
                                    }
                                }
                                return C;
                            }
                        
                            // Function to subtract two matrices
                            public static int[][] subtract(int[][] A, int[][] B) {
                                int n = A.length;
                                int[][] C = new int[n][n];
                                for (int i = 0; i < n; i++) {
                                    for (int j = 0; j < n; j++) {
                                        C[i][j] = A[i][j] - B[i][j];
                                    }
                                }
                                return C;
                            }
                        
                            // Function to multiply two matrices using Strassen's algorithm
                            public static int[][] strassen(int[][] A, int[][] B) {
                                int n = A.length;
                                if (n == 1) {
                                    int[][] C = new int[1][1];
                                    C[0][0] = A[0][0] * B[0][0];
                                    return C;
                                }
                        
                                int newSize = n / 2;
                                int[][] A11 = new int[newSize][newSize];
                                int[][] A12 = new int[newSize][newSize];
                                int[][] A21 = new int[newSize][newSize];
                                int[][] A22 = new int[newSize][newSize];
                                int[][] B11 = new int[newSize][newSize];
                                int[][] B12 = new int[newSize][newSize];
                                int[][] B21 = new int[newSize][newSize];
                                int[][] B22 = new int[newSize][newSize];
                        
                                // Dividing matrices into 4 sub-matrices
                                for (int i = 0; i < newSize; i++) {
                                    for (int j = 0; j < newSize; j++) {
                                        A11[i][j] = A[i][j];
                                        A12[i][j] = A[i][j + newSize];
                                        A21[i][j] = A[i + newSize][j];
                                        A22[i][j] = A[i + newSize][j + newSize];
                                        B11[i][j] = B[i][j];
                                        B12[i][j] = B[i][j + newSize];
                                        B21[i][j] = B[i + newSize][j];
                                        B22[i][j] = B[i + newSize][j + newSize];
                                    }
                                }
                        
                                int[][] M1 = strassen(add(A11, A22), add(B11, B22));
                                int[][] M2 = strassen(add(A21, A22), B11);
                                int[][] M3 = strassen(A11, subtract(B12, B22));
                                int[][] M4 = strassen(A22, subtract(B21, B11));
                                int[][] M5 = strassen(add(A11, A12), B22);
                                int[][] M6 = strassen(subtract(A21, A11), add(B11, B12));
                                int[][] M7 = strassen(subtract(A12, A22), add(B21, B22));
                        
                                int[][] C11 = add(subtract(add(M1, M4), M5), M7);
                                int[][] C12 = add(M3, M5);
                                int[][] C21 = add(M2, M4);
                                int[][] C22 = add(subtract(add(M1, M3), M2), M6);
                        
                                int[][] C = new int[n][n];
                                for (int i = 0; i < newSize; i++) {
                                    for (int j = 0; j < newSize; j++) {
                                        C[i][j] = C11[i][j];
                                        C[i][j + newSize] = C12[i][j];
                                        C[i + newSize][j] = C21[i][j];
                                        C[i + newSize][j + newSize] = C22[i][j];
                                    }
                                }
                        
                                return C;
                            }
                        
                            // Function to print the matrix
                            public static void printMatrix(int[][] matrix) {
                                int n = matrix.length;
                                for (int i = 0; i < n; i++) {
                                    for (int j = 0; j < n; j++) {
                                        System.out.print(matrix[i][j] + " ");
                                    }
                                    System.out.println();
                                }
                            }
                        
                            public static void main(String[] args) {
                                // Example 4 x 4 inputs (the size must be a power of 2)
                                int[][] A = {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}};
                                int[][] B = {{16, 15, 14, 13}, {12, 11, 10, 9}, {8, 7, 6, 5}, {4, 3, 2, 1}};
                        
                                int[][] C = strassen(A, B);
                                System.out.println("Product Matrix:");
                                printMatrix(C);
                            }
                        }
                        
                    
                
                    
                        import numpy as np

                        def add(A, B):
                            return np.add(A, B)
                        
                        def subtract(A, B):
                            return np.subtract(A, B)
                        
                        def strassen(A, B):
                            n = A.shape[0]
                            if n == 1:
                                return A * B
                        
                            new_size = n // 2
                            A11 = A[:new_size, :new_size]
                            A12 = A[:new_size, new_size:]
                            A21 = A[new_size:, :new_size]
                            A22 = A[new_size:, new_size:]
                            B11 = B[:new_size, :new_size]
                            B12 = B[:new_size, new_size:]
                            B21 = B[new_size:, :new_size]
                            B22 = B[new_size:, new_size:]
                        
                            M1 = strassen(add(A11, A22), add(B11, B22))
                            M2 = strassen(add(A21, A22), B11)
                            M3 = strassen(A11, subtract(B12, B22))
                            M4 = strassen(A22, subtract(B21, B11))
                            M5 = strassen(add(A11, A12), B22)
                            M6 = strassen(subtract(A21, A11), add(B11, B12))
                            M7 = strassen(subtract(A12, A22), add(B21, B22))
                        
                            C11 = add(subtract(add(M1, M4), M5), M7)
                            C12 = add(M3, M5)
                            C21 = add(M2, M4)
                            C22 = add(subtract(add(M1, M3), M2), M6)
                        
                            C = np.zeros((n, n), dtype=A.dtype)
                            C[:new_size, :new_size] = C11
                            C[:new_size, new_size:] = C12
                            C[new_size:, :new_size] = C21
                            C[new_size:, new_size:] = C22
                        
                            return C
                        
                        # Example usage
                        A = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
                        B = np.array([[16, 15, 14, 13], [12, 11, 10, 9], [8, 7, 6, 5], [4, 3, 2, 1]])
                        
                        C = strassen(A, B)
                        print("Product Matrix:")
                        print(C)
                        
                    
                

Analysis of Algorithm

Time Complexity of Divide and Conquer Approach


For n × n matrix multiplication using Strassen’s algorithm, the time complexity is derived based on the divide and conquer approach.


Key Points:


  • Base Case: If n = 1, the algorithm performs 1 multiplication. So, T(1) = 1.

  • Recursive Case: For n = 2^k, where k = log2 n:
    • The algorithm makes 7 recursive calls, each operating on a subproblem of size n/2.

    • The recurrence relation is T(n) = 7T(n/2) + cn^2, where cn^2 is the cost of the matrix additions and subtractions at that level.

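Deriving the Time Complexity:


  1. Start with the recurrence relation, counting only the recursive multiplications for now (the cn^2 term is added back in step 4):
    T(n) = 7T(n/2)

  2. Expand it by substituting recursively:
    T(n) = 7(7T(n/4)) = 7^2 T(n/4)
    T(n) = 7^k T(n/2^k)
  3. Since n = 2^k (so k = log2 n) and T(1) = 1, the final expansion is:
    T(n) = 7^(log2 n) ⋅ T(1) = n^(log2 7)

    Here, log2 7 ≈ 2.807.

  4. Adding the cn^2 term back: level i of the recursion has 7^i subproblems of size n/2^i, so the additions at that level cost cn^2 (7/4)^i. Summed over the k levels this is (4c/3)(7^k − n^2), which is also O(n^(log2 7)), so the bound is unchanged.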
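The expansion can be checked numerically with a short sketch. The function T below simply evaluates the recurrence for powers of 2; the constant c = 3 in the second check is an arbitrary value chosen for this illustration so that the closed form stays in whole numbers.

```python
import math


def T(n, c=0):
    """Evaluate T(n) = 7*T(n/2) + c*n^2 with T(1) = 1, for n a power of 2."""
    if n == 1:
        return 1
    return 7 * T(n // 2, c) + c * n * n


for k in range(1, 6):
    n = 2 ** k
    # Multiplications only (c = 0): exactly 7^k = 7^(log2 n).
    assert T(n) == 7 ** k
    # With c = 3 the closed form is 5*7^k - 4*4^k, still dominated by 7^k.
    assert T(n, c=3) == 5 * 7 ** k - 4 * 4 ** k

print(math.log2(7))  # 2.807..., the exponent in n^(log2 7)
```

Both checks show the 7^k term dominating, so including the addition cost changes only the constant factor.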


Conclusion:


The time complexity of Strassen’s algorithm is:


T(n) = O(n^(log2 7)) ≈ O(n^2.807)

This is an improvement over the standard matrix multiplication time complexity of O(n^3).


Time Complexity Analysis (Expanded)


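  • Recurrence Relation:
    T(n) = 7T(n/2) + cn^2

    where cn^2 accounts for the matrix addition operations.


  • Master Theorem Application:

    Here a = 7, b = 2, and f(n) = cn^2. Since log2 7 ≈ 2.807 is greater than 2, f(n) grows polynomially slower than n^(log2 7), so the first case of the Master Theorem gives T(n) = O(n^(log2 7)) = O(n^2.807).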


Summary:


Strassen’s algorithm reduces the time complexity of matrix multiplication from O(n^3) to O(n^2.807), making it more efficient for large matrices.
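To give a feel for the gap, the sketch below compares scalar multiplication counts only. It assumes the recursion runs all the way down to 1 x 1 blocks and ignores additions, memory traffic, and recursion overhead, so it does not predict where Strassen's method becomes faster in practice. The function names are illustrative.

```python
def standard_mults(n):
    """Scalar multiplications in the brute force method: n^3."""
    return n ** 3


def strassen_mults(n):
    """Scalar multiplications in Strassen's method for n a power of 2."""
    return 1 if n == 1 else 7 * strassen_mults(n // 2)


for n in (2, 4, 64, 1024):
    print(n, standard_mults(n), strassen_mults(n))
# 2 8 7
# 4 64 49
# 64 262144 117649
# 1024 1073741824 282475249
```

At n = 1024 the count drops to roughly a quarter of the brute force figure, and the ratio keeps shrinking as n grows.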